Grug #2171
Conversation
# Conflicts:
#   lib/levanter/src/levanter/data/mixture.py
#   lib/marin/pyproject.toml
#   pyproject.toml
#   uv.lock
This pull request has been inactive for 23 days and is marked as stale.
bump
Pushed fix for TPU Splash attention crashing during init on tracers (fallback to … when … is unavailable) + removed an unsupported arg from the Grugformer speedrun wrapper. Commit: 0ee618f.
Follow-up (previous comment had shell quoting issues): fix uses …
Added an inline note + refactor in …
This is awesome! I may be a little aggressive with the comments about deleting "unused" logic/options or reducing the number of files - this is mostly in the spirit of karpathy-ish code 🙇 There are a couple of logic questions in here. Some nits as well, like the `__all__`, which I dislike.[^1]
Footnotes

[^1]: I prefer protected-ish `_` prefixes, but if marin has a policy on `__all__` I'm happy to adjust. ↩
```toml
testpaths = ["tests", "experiments"]

# Make sure we timeout before CI kills us, and don't run TPU or slow tests by default
addopts = "--session-timeout=480 -m 'not tpu_ci and not slow'"
```
Is this intentional?
```python
runtime_dict = {
    "working_dir": current_dir,
    "config": {"setup_timeout_seconds": 1800},
    "excludes": [".git", "tests/", "docs/", "**/*.pack", "lib/levanter/docs"],
```
❓ what is the purpose of this change?
so I can run tests with `ray_run`, which I do
FYI when I run the starter speedrun (130M only) in us-central1 on TPU (v5p-8), I get an OOM. I can work around this, but I wonder whether that was supposed to work?
```python
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run the Grug trainer.")
    parser.add_argument("--cache-dir", type=str, default=None, help="Optional TreeCache directory for real data.")
```
Could we make it simpler to run main.py in isolation without depending on TreeCache or synthetic data? I.e., point it at a directory (object-store compatible) dump of some canonical dataset, e.g. OpenWebText, FineWeb, or even TinyStories?
i'd rather not spend too much time on this given that the main way we'll be running it is via marin's training harness.
ravwojdyla left a comment
Some more comments from experiments
experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py (outdated; resolved)
```python
# Grug core currently always has separate token embed + output projection; keep this knob
# for param counting / compatibility with other LmConfig-based scripts.
tie_embeddings: bool = False
```
Wouldn't it be more intuitive to not expose this config and instead hard-code the logic in total_trainable_params? Otherwise it may seem like this flag does something when it doesn't.
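One possible shape of that suggestion (a sketch with hypothetical names and a simplified parameter breakdown, not the actual wrapper code):

```python
def total_trainable_params(vocab_size: int, d_model: int, n_layers: int, params_per_layer: int) -> int:
    # Grug core always keeps a separate token embedding and output
    # projection, so both matrices are counted unconditionally instead
    # of branching on a tie_embeddings flag the core never reads.
    embed = vocab_size * d_model    # token embedding table
    unembed = vocab_size * d_model  # output projection (never tied)
    return embed + unembed + n_layers * params_per_layer
```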
```python
num_kv_heads=self.num_kv_heads,
head_dim=self.head_dim,
max_seq_len=self.max_seq_len,
tie_embeddings=self.tie_embeddings,
```
`tie_embeddings` doesn't exist in `GrugModelConfig` (see other comment)
```python
labels = jnp.concatenate([token_ids[:, 1:], token_ids[:, :1] * 0], axis=1).astype(jnp.int32)
loss_weight = loss_weight.astype(loss_dtype)

# NOTE: `block_size=None` corresponds to a single full-vocab block. On the 125M speedrun,
```
❓ #2315 (comment) ptal, I can't reproduce this 🙏
Btw, setting `cross_entropy_block_size` to, say, ~32k on v5p-8 OOMs in the 125M experiment.
yeah somehow this version of fused cross entropy doesn't actually work super well? don't really understand why
i'm gonna replace with a pallas kernel at some point
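For context, the blockwise trick under discussion amounts to an online logsumexp over vocab blocks, so the full `[N, V]` logits matrix is never materialized at once. A toy sketch (not the PR's `loss.py`; assumes the vocab size is divisible by the block size):

```python
import jax
import jax.numpy as jnp

def blockwise_cross_entropy(hidden, w_out, labels, block_size):
    # hidden: [N, D], w_out: [D, V], labels: [N] int token ids.
    # run_max / run_sum maintain an online logsumexp across vocab blocks.
    n = hidden.shape[0]
    vocab = w_out.shape[1]
    rows = jnp.arange(n)
    run_max = jnp.full((n,), -jnp.inf)
    run_sum = jnp.zeros((n,))
    label_logit = jnp.zeros((n,))
    for start in range(0, vocab, block_size):
        logits_blk = hidden @ w_out[:, start:start + block_size]  # [N, block]
        blk_max = logits_blk.max(axis=-1)
        new_max = jnp.maximum(run_max, blk_max)
        # rescale the running sum to the new max, then add this block's mass
        run_sum = run_sum * jnp.exp(run_max - new_max) + jnp.exp(
            logits_blk - new_max[:, None]
        ).sum(axis=-1)
        run_max = new_max
        # pick out the label's logit if it falls in this block
        in_blk = (labels >= start) & (labels < start + block_size)
        idx = jnp.clip(labels - start, 0, block_size - 1)
        label_logit = jnp.where(in_blk, logits_blk[rows, idx], label_logit)
    # NLL = logsumexp(logits) - logit_at_label
    return run_max + jnp.log(run_sum) - label_logit
```

The tradeoff the thread is circling: each block still materializes an `[N, block]` slab plus softmax temporaries, so a large `block_size` on a small slice (e.g. ~32k on v5p-8) can still OOM; a Pallas kernel can fuse the matmul with the reduction instead.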
```markdown
## Working Agreement: How Grug Evolves

- Canonical “best guess” lives in `lib/levanter/src/levanter/grug/`.
```
grug is a reference - should all the references live in references or something?
sure. let's not do that here though since the reorg doesn't exist yet?
```markdown
- Minimal surface: plain pytrees + explicit mesh + small config.
- Owns data loading, checkpointing, and evaluation in a way that’s easy to copy/paste into speedrun scripts.

2) **Evolve Levanter/Marin to support grug natively**
```
Potentially dumb question, but why can't we make it so we call grug directly instead of even dealing with levanter? Dataloading can be factored out into its own thing maybe?
that is the goal yes
i think this is a good checkpoint and a next step is to isolate pieces of levanter that are still useful and make them more grug-aware
```python
if isinstance(mask, AttentionMask):
    mask = mask.materialize_mask(scores.shape[-2], scores.shape[-1])

if mask is not None:
```
kinda hate this. should maybe simplify i dunno
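For reference, the mask-handling being debated boils down to the reference path below (a toy single-batch sketch with a plain boolean mask, not the actual grug `attention.py`):

```python
import jax
import jax.numpy as jnp

def reference_attention(q, k, v, mask=None):
    # q, k, v: [seq, heads, head_dim]; mask: bool [q_len, k_len], True = attend.
    scores = jnp.einsum("qhd,khd->hqk", q, k) / jnp.sqrt(jnp.float32(q.shape[-1]))
    if mask is not None:
        # masked-out positions get a huge negative score -> ~0 softmax weight
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("hqk,khd->qhd", weights, v)
```

The materialize-then-`where` step is what makes the `AttentionMask`-vs-array branching feel awkward: the spec type exists so the Splash path can stay lazy, but the reference path still needs a dense boolean array.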
```python
DEFAULT_AXIS_MAPPING = {"batch": ("replica_dcn", "replica", "data")}


def make_token_dataset(cache: TreeCache[dict], *, seq_len: int) -> TokenSeqDataset:
```
Food for thought: it's not clear how a user would do clever data loading tricks here such as @ClassicLarry's Document Alignment. Fine if we decide we want to grugify that part later?
More generally, the data loader seems pretty non-gruggy to me, since the user still has to go to levanter to figure out what these return types are and how to use them if they did want to customize stuff here.
i'd like to do later yes please
gruggifying data should be later
Happy to take a pass at Gruggifying data stuff if you'd like since I am at least somewhat familiar with the types from writing the audio loader
i was in the middle of a different branch to clean it up for other purposes, maybe after that?
Whatever works best - just don't want you to feel like you are responsible for all gruggifying if you don't want to be
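As a reference point for how small a fully gruggified loader could get: a fixed-window dataset over one flat token array is only a few lines (a hypothetical stand-in, not the TreeCache-backed `TokenSeqDataset`, and it ignores shuffling and sharding):

```python
import numpy as np

class FlatTokenDataset:
    """Fixed-length windows over a single flat token array."""

    def __init__(self, tokens: np.ndarray, seq_len: int):
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self) -> int:
        # number of complete windows; a trailing partial window is dropped
        return len(self.tokens) // self.seq_len

    def __getitem__(self, i: int) -> np.ndarray:
        start = i * self.seq_len
        return self.tokens[start:start + self.seq_len]
```

Tricks like document alignment would then be a matter of changing how `__getitem__` picks window boundaries, which is the kind of edit surface the thread is asking for.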
```python
import jax.numpy as jnp
from einops import rearrange
from jax import random
from jax.sharding import PartitionSpec as P, reshard
```
Nit: I don't love the alias P here because 1) it's hard to grep for and 2) I don't usually read new code header first, so it's hard to know what this is on first pass.
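i.e., the unaliased spelling, which stays greppable at call sites (illustrative only):

```python
from jax.sharding import PartitionSpec

# Spelled out, a batch-sharded 2-D spec reads unambiguously at the call site:
spec = PartitionSpec("data", None)
```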
ravwojdyla left a comment
lgtm!
This PR introduces Grugformer: a “grug-simple” JAX LM implementation that leans into explicit sharding and top-level functions rather than heavy abstractions. It adds a minimal core (`levanter.grug`) plus a small adapter (`levanter.models.grug_wrapper`) so it can run through the existing Levanter trainer pipeline, and it includes speedrun entrypoints + tests that lock down the intended “grug core surface”.

## What’s Included

### New Grug core (minimal, notebook-like)

- `lib/levanter/src/levanter/grug/attention.py`: Grug-local `AttentionMask` spec + attention implementation (TPU Splash when on TPU; reference fallback otherwise).
- `model.py`: parameter dataclasses + init/forward/activations/loss functions.
- `loss.py`: blockwise “large vocab friendly” CE path (avoid full logits materialization; see note below on tradeoffs).
- `data.py`, `main.py`: minimal training/data wiring to run in-repo.

### Levanter adapter

- `lib/levanter/src/levanter/models/grug_wrapper.py`: wraps the grug core behind Levanter’s `LmConfig`/trainer expectations while keeping the core itself free of NamedArray-heavy abstractions.

### Speedruns / templates

- `experiments/speedrun/grugformer_starter/grugformer_speedrun.py`: a grug speedrun template for quick iteration.
- `experiments/speedrun/grugformer_attnsink/grugformer_attn_sink.py`: “hackable” grug attention-sink variant (copy/paste edit surface).
- `experiments/speedrun/grugformer_vs_hackable_125m/grugformer_vs_hackable_125m.py`: head-to-head comparison (Hackable Transformer vs Grugformer, no sinks). Hackable path runs without explicit mesh axes for now.

### Tests (lock the “grug core surface”)

In `lib/levanter/tests/grug/`:

- `test_grugformer_core.py`: core API + mesh/sharding sanity.
- `test_grugformer_model_loss.py`: loss correctness vs full logits on small shapes; wrapper plumbing.
- `test_grugformer_fused_loss.py`: loss-related regression coverage.
- `test_grugformer_compilation.py`: lowers/jit-traces model+loss under `AbstractMesh` (no concrete devices required).
- `test_grugformer.py`: higher-level smoke coverage (tiny synthetic step).

### Documentation

- `.agents/projects/grugformer.md`: principles, intended edit surface, and follow-ups.
- `docs/recipes/change_grug.md`: workflow for proposing changes (speedrun edit surface → adopt into canonical grug → archive old experiments).
- `docs/reports/grug-archive.md`: lightweight “experiment archive log” placeholder so we have somewhere to record removals/renames as grug evolves.

## Notable Design Choices / Current Constraints

## How To Try

- `python -m experiments.speedrun.grugformer_vs_hackable_125m.grugformer_vs_hackable_125m` (set `SR_USE_TPU=1` to use the TPU preset).
- `uv run pytest lib/levanter/tests/grug -q`

## Follow-ups